import argparse
import json
import os
import pathlib
import typing
import warnings

import dotenv
import filelock
import numpy as np
import opacus
import opacus.utils.batch_memory_manager
import sklearn.metrics
import torch
import torch.utils.data
import torchdata.dataloader2
import torchdata.datapipes.map
import torchvision
import torchvision.transforms.v2
import tqdm

import base
import data
import dpsgd_utils
import networks


def _save_or_verify(path: str, array: np.ndarray) -> None:
    """Write ``array`` to ``path`` if it does not exist yet, otherwise verify the stored copy.

    Every shadow-model run regenerates the shared LiRA artifacts from the same
    seeded ``data.DatasetGenerator``. The first run to reach this point writes them.
    Later runs check that they derived identical arrays, so a run launched with
    different generator settings is reported instead of silently mixed in.

    The array is first written to a temporary file in the same directory. It is then
    moved into place with ``os.replace``, so a concurrently running job never
    ``np.load``s a partially written file.

    Args:
        path: Destination ``.npy`` path; its directory must already exist.
        array: Array to store or compare against.

    Raises:
        RuntimeError: If ``path`` already exists but holds different data.
    """
    import contextlib
    import tempfile

    if os.path.exists(path):
        if not np.array_equal(array, np.load(path)):
            raise RuntimeError(
                f"Existing LiRA artifact {path} does not match the data generated for this run; "
                "check that the seed, shadow/canary/poison settings and data match earlier runs."
            )
        return

    fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(path) or ".", suffix=".tmp")
    try:
        # Writing through a file object stops np.save from appending another ".npy" suffix.
        with os.fdopen(fd, "wb") as tmp_file:
            np.save(tmp_file, array)
        os.replace(tmp_path, path)
    except BaseException:
        # Do not leave a stray temporary file behind on failure.
        with contextlib.suppress(FileNotFoundError):
            os.remove(tmp_path)
        raise


def main():
    """Train the single LiRA shadow model selected by ``--exp_id``.

    Steps:

    1. Parse the CLI args and build the seeded ``data.DatasetGenerator``.
    2. Write (or verify) the artifacts shared by all shadow models under ``--lira_path``:
       ``noisy_targets.npy``, ``canary_indices.npy`` and ``indices/indice_<exp_id>.npy``.
    3. Train the shadow model and save it under ``<lira_path>/ckpts``.

    Raises:
        ValueError: If a count argument or ``--exp_id`` is out of range.
        RuntimeError: If an existing shared artifact does not match this run's data.
    """
    dotenv.load_dotenv()
    args = parse_args()
    data_dir = args.data_path

    verbose = args.verbose

    global_seed = args.seed
    base.setup_seeds(global_seed)

    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    num_shadow = args.num_shadow
    if num_shadow <= 0:
        raise ValueError(f"--num-shadow must be positive, got {num_shadow}")
    num_canaries = args.num_canaries
    if num_canaries <= 0:
        raise ValueError(f"--num-canaries must be positive, got {num_canaries}")
    num_poison = args.num_poison
    if num_poison < 0:
        raise ValueError(f"--num-poison must be non-negative, got {num_poison}")
    shadow_model_idx = args.exp_id
    if not 0 <= shadow_model_idx < num_shadow:
        raise ValueError(f"--exp_id must be in [0, {num_shadow}), got {shadow_model_idx}")

    data_generator = data.DatasetGenerator(
        num_shadow=num_shadow,
        num_canaries=num_canaries,
        canary_type=data.CanaryType(args.canary_type),
        num_poison=num_poison,
        poison_type=data.PoisonType(args.poison_type),
        data_dir=data_dir,
        seed=global_seed,
        download=True,
    )

    # Re-seed per shadow model so that training randomness differs between shadow models.
    setting_seed = base.get_setting_seed(
        global_seed=global_seed, shadow_model_idx=shadow_model_idx, num_shadow=num_shadow
    )
    base.setup_seeds(setting_seed)

    noisy_targets, _, _ = data_generator.build_attack_data()
    canary_indices = data_generator.get_canary_indices()
    # NOTE(review): reads a private attribute of DatasetGenerator; presumably a list indexed
    # by shadow-model index holding that model's "in" sample indices - confirm.
    shadow_in_indices = data_generator._shadow_in_indices

    # Creating `indices/` also creates `lira_path` itself, which the other two files need.
    indices_path = os.path.join(args.lira_path, "indices")
    os.makedirs(indices_path, exist_ok=True)
    _save_or_verify(os.path.join(args.lira_path, "noisy_targets.npy"), np.array(noisy_targets))
    _save_or_verify(os.path.join(args.lira_path, "canary_indices.npy"), np.array(canary_indices))
    _save_or_verify(
        os.path.join(indices_path, f"indice_{shadow_model_idx}.npy"),
        np.array(shadow_in_indices[shadow_model_idx]),
    )

    output_dir = os.path.join(args.lira_path, "ckpts")
    os.makedirs(output_dir, exist_ok=True)

    _run_train(
        args,
        output_dir,
        shadow_model_idx,
        data_generator,
        setting_seed,
        verbose,
    )


def _run_train(
    args: argparse.Namespace,
    output_dir: str,
    shadow_model_idx: int,
    data_generator: data.DatasetGenerator,
    training_seed: int,
    verbose: bool,
) -> None:
    """Train one shadow model and save it as ``model_last_<idx>.pt`` in ``output_dir``.

    Args:
        args: Parsed CLI namespace; only the training hyperparameters are read here.
        output_dir: Existing directory that receives the checkpoint.
        shadow_model_idx: Index of the shadow model whose training split is built.
        data_generator: Generator used to build this shadow model's training data.
        training_seed: Forwarded unchanged to ``_train_model``.
        verbose: Forwarded unchanged to ``_train_model``.
    """
    # Collect the training hyperparameters from the CLI namespace.
    hyperparams = {
        "num_epochs": args.num_epochs,
        "noise_multiplier": args.noise_multiplier,
        "max_grad_norm": args.max_grad_norm,
        "learning_rate": args.learning_rate,
        "batch_size": args.batch_size,
        "augmult_factor": args.augmult_factor,
    }

    canary_summary = f"{data_generator.num_canaries} canaries ({data_generator.canary_type.value}), "
    poison_summary = f"{data_generator.num_poison} poisons ({data_generator.poison_type.value})"
    print(f"Training shadow model {shadow_model_idx}")
    print(canary_summary + poison_summary)

    shadow_train_data = data_generator.build_train_data(shadow_model_idx=shadow_model_idx)
    trained_model = _train_model(
        shadow_train_data,
        data_generator=data_generator,
        training_seed=training_seed,
        verbose=verbose,
        **hyperparams,
    )
    trained_model.eval()

    # The whole module object is pickled, not just its state_dict.
    checkpoint_path = os.path.join(output_dir, f"model_last_{shadow_model_idx}.pt")
    torch.save(trained_model, checkpoint_path)
    print("Saved model")
    print("Saved model")


def _train_model(
    train_data: data.Dataset,
    data_generator: data.DatasetGenerator,
    training_seed: int,
    num_epochs: int,
    noise_multiplier: float,
    max_grad_norm: float,
    learning_rate: float,
    batch_size: int,
    augmult_factor: int,
    verbose: bool = False,
) -> torch.nn.Module:
    """Train a ResNet18 with DP-SGD and augmentation multiplicity, and return its EMA copy.

    Plain SGD (no momentum, no weight decay) is wrapped by
    ``dpsgd_utils.PrivacyEngineAugmented`` with Poisson sampling. Logical batches are
    split into physical batches of at most 64 by Opacus' ``BatchMemoryManager``. The
    EMA copy of the weights is updated only when the optimizer takes a real step,
    i.e. at the end of a logical batch.

    Args:
        train_data: Training set; ``build_map_datapipe()`` is called on it and every
            loader batch is unpacked as ``(inputs, labels)``.
        data_generator: Not read by this function (kept for interface compatibility).
        training_seed: Not read by this function (kept for interface compatibility).
        num_epochs: Number of passes over the private data loader.
        noise_multiplier: Gaussian noise multiplier for DP-SGD.
        max_grad_norm: Per-sample gradient clipping norm.
        learning_rate: SGD learning rate.
        batch_size: Logical (virtual) batch size given to the data loader.
        augmult_factor: Number of augmented copies per sample; 0 disables augmentation.
        verbose: If True, show a per-batch progress bar within each epoch.

    Returns:
        The EMA model created by ``dpsgd_utils.create_ema``, switched to eval mode.
    """
    momentum = 0
    weight_decay = 0

    model = _build_model(num_classes=10)
    model.train()

    with warnings.catch_warnings():
        # NOTE(review): `module` is matched against the module that *issues* the warning.
        # If Opacus emits "Secure RNG turned off" from its own module, this filter will
        # not match it - confirm.
        warnings.filterwarnings("ignore", category=UserWarning, message="Secure RNG turned off", module="dpsgd_utils")
        privacy_engine = dpsgd_utils.PrivacyEngineAugmented(opacus.GradSampleModule.GRAD_SAMPLERS)

    train_datapipe = train_data.build_map_datapipe()

    # FIXME: Theoretically, the use of global mean-std normalization incurs a privacy cost that is unaccounted.
    #  However, the normalization constants are w.r.t. the full CIFAR10 data
    #  and I guess glancing over that fact here is not a big issue.
    # NB: This always applies normalization before potential data augmentation
    train_datapipe = train_datapipe.map(
        torchvision.transforms.v2.Compose(
            [
                torchvision.transforms.v2.ConvertDtype(base.DTYPE),
                torchvision.transforms.v2.Normalize(
                    mean=data.CIFAR10_MEAN,
                    std=data.CIFAR10_STD,
                ),
            ]
        ),
    )
    loss = torch.nn.CrossEntropyLoss(reduction="mean")
    train_loader = torch.utils.data.DataLoader(
        train_datapipe,
        drop_last=False,
        num_workers=4,
        batch_size=batch_size,
        shuffle=True,
    )
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)

    dp_delta = 1e-5
    model, optimizer, train_loader = privacy_engine.make_private(
        module=model,
        optimizer=optimizer,
        data_loader=train_loader,
        noise_multiplier=noise_multiplier,
        max_grad_norm=max_grad_norm,
        poisson_sampling=True,  # guarantees proper privacy, but could also disable for simplicity
        K=augmult_factor,
        loss_reduction=loss.reduction,
    )

    # Prepare augmult and exponential moving average as in tan code
    # Need to always use custom hooks, else the modified privacy engine will not work with augmult 0
    augmentation = dpsgd_utils.AugmentationMultiplicity(augmult_factor if augmult_factor > 0 else 1)
    model.GRAD_SAMPLERS[torch.nn.modules.conv.Conv2d] = augmentation.augmented_compute_conv_grad_sample
    model.GRAD_SAMPLERS[torch.nn.modules.linear.Linear] = augmentation.augmented_compute_linear_grad_sample
    model.GRAD_SAMPLERS[torch.nn.GroupNorm] = augmentation.augmented_compute_group_norm_grad_sample
    ema_model = dpsgd_utils.create_ema(model)

    # Random crop + flip applied to each augmented copy. Built once here instead of
    # once per physical batch.
    sample_augmentation = torchvision.transforms.v2.Compose(
        [
            torchvision.transforms.v2.RandomCrop(32, padding=4),
            torchvision.transforms.v2.RandomHorizontalFlip(),
        ]
    )

    num_true_updates = 0
    for _ in (pbar := tqdm.trange(num_epochs, desc="Training", unit="epoch")):
        model.train()
        with opacus.utils.batch_memory_manager.BatchMemoryManager(
            data_loader=train_loader,
            # 64 is max that always fits into 3090 VRAM w/o memory fragmentation issues for 16 augmult factor
            # Using 128 does not seem to significantly improve performance
            max_physical_batch_size=64,
            optimizer=optimizer,
        ) as optimized_train_loader:
            num_samples = 0
            epoch_loss = 0.0
            epoch_accuracy = 0.0
            for batch_xs, batch_ys in tqdm.tqdm(
                optimized_train_loader,
                desc="Current epoch",
                unit="batch",
                leave=False,
                disable=not verbose,
            ):
                batch_xs = batch_xs.to(base.DEVICE)
                batch_ys = batch_ys.to(base.DEVICE)
                original_batch_size = batch_xs.size(0)

                optimizer.zero_grad(set_to_none=True)
                # Poisson sampling can yield empty batches; torch.stack would fail on the
                # resulting empty list, so augmentation is skipped for them.
                if augmult_factor > 0 and original_batch_size > 0:
                    # Taken from tan code: repeat each sample `augmult_factor` times and
                    # augment every copy independently.
                    batch_xs = torch.repeat_interleave(batch_xs, repeats=augmult_factor, dim=0)
                    batch_ys = torch.repeat_interleave(batch_ys, repeats=augmult_factor, dim=0)
                    batch_xs = torch.stack([sample_augmentation(x_) for x_ in batch_xs])
                    assert batch_xs.size(0) == augmult_factor * original_batch_size

                batch_pred = model(batch_xs)
                batch_loss = loss(input=batch_pred, target=batch_ys)
                batch_loss.backward()

                # Only update ema at the end of actual batches
                will_update = not optimizer._check_skip_next_step(pop_next=False)
                optimizer.step()
                if will_update:
                    num_true_updates += 1
                    dpsgd_utils.update_ema(model, ema_model, num_true_updates)

                # Statistics are over the augmented batch. Empty batches are skipped because
                # a mean loss over zero elements is NaN and would poison the epoch average.
                physical_batch_size = batch_xs.size(0)
                if physical_batch_size > 0:
                    batch_unnormalized_accuracy = (batch_pred.detach().argmax(-1) == batch_ys).int().sum().item()
                    epoch_loss += batch_loss.item() * physical_batch_size
                    epoch_accuracy += batch_unnormalized_accuracy
                    num_samples += physical_batch_size
            # Guard against an epoch in which every sampled batch was empty.
            epoch_loss /= max(num_samples, 1)
            epoch_accuracy /= max(num_samples, 1)
            dp_eps = privacy_engine.get_epsilon(dp_delta)
            progress_dict = {
                "epoch_loss": epoch_loss,
                "epoch_accuracy": epoch_accuracy,
                "dp_eps": dp_eps,
                "dp_delta": dp_delta,
                "update_steps": num_true_updates,
            }

            pbar.set_postfix(progress_dict)

    ema_model.eval()
    return ema_model


def _build_model(num_classes: int) -> networks.ResNet18:
    """Create a 3-channel ``networks.ResNet18`` and place it on ``base.DEVICE``.

    Args:
        num_classes: Number of output classes of the classifier head.

    Returns:
        The freshly initialised network, already moved to ``base.DEVICE``.
    """
    network = networks.ResNet18(channel=3, num_classes=num_classes)
    return network.to(base.DEVICE)


def parse_args(argv: typing.Optional[typing.Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line arguments for training a single LiRA shadow model.

    Args:
        argv: Argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly like the previous zero-argument signature.

    Returns:
        The parsed namespace. Dashed flags are exposed with underscores
        (e.g. ``--num-shadow`` becomes ``args.num_shadow``).
    """
    parser = argparse.ArgumentParser()

    # General args.
    # The underscore spellings are the original flag names and stay the `dest`, because
    # argparse derives `dest` from the first long option. The dashed aliases match the
    # naming style of all other flags.
    parser.add_argument("--data_path", "--data-path", type=str, required=True, help="Dataset root directory")
    parser.add_argument(
        "--lira_path",
        "--lira-path",
        type=str,
        required=True,
        help="path to save LiRA files, e.g., in-out-split indices, canaries indices, and noise targets",
    )
    parser.add_argument("--verbose", action="store_true", help="Show a per-batch progress bar within each epoch")

    # Dataset and setup args
    parser.add_argument("--seed", required=True, type=int, help="Global random seed")
    parser.add_argument("--num-shadow", type=int, default=64, help="Number of shadow models")
    parser.add_argument("--num-canaries", type=int, default=500, help="Number of canaries to audit")
    parser.add_argument(
        "--canary-type",
        type=data.CanaryType,
        default=data.CanaryType.CLEAN,
        choices=list(data.CanaryType),
        help="Type of canary to use",
    )
    parser.add_argument("--num-poison", type=int, default=0, help="Number of poison samples to include")
    parser.add_argument(
        "--poison-type",
        type=data.PoisonType,
        default=data.PoisonType.CANARY_DUPLICATES,
        choices=list(data.PoisonType),
        help="Type of poisoning to use",
    )

    parser.add_argument(
        "--exp_id",
        "--exp-id",
        type=int,
        required=True,
        help="Index of the shadow model to train (must satisfy 0 <= exp_id < num-shadow)",
    )

    # Defense-specific
    parser.add_argument("--num-epochs", type=int, default=200, help="Number of training epochs")
    parser.add_argument("--noise-multiplier", type=float, default=0.2, help="Gaussian noise multiplier")
    parser.add_argument("--max-grad-norm", type=float, default=1.0, help="Gradient clipping norm")
    parser.add_argument("--learning-rate", type=float, default=4.0, help="Learning rate")
    parser.add_argument("--batch-size", type=int, default=2048, help="Virtual batch size")
    parser.add_argument("--augmult-factor", type=int, default=8, help="Number of data augmentations per sample")

    return parser.parse_args(argv)


if __name__ == "__main__":
    # Script entry point: trains the single shadow model selected by --exp_id.
    main()
